{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dataclasses\n",
    "\n",
    "import jax\n",
    "\n",
    "from openpi.models import model as _model\n",
    "from openpi.policies import droid_policy\n",
    "from openpi.policies import policy_config as _policy_config\n",
    "from openpi.shared import download\n",
    "from openpi.training import config as _config\n",
    "from openpi.training import data_loader as _data_loader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Policy inference\n",
    "\n",
    "The following example shows how to create a policy from a checkpoint and run inference on a dummy example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = _config.get_config(\"pi0_fast_droid\")\n",
    "checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_fast_droid\")\n",
    "\n",
    "# Create a trained policy.\n",
    "policy = _policy_config.create_trained_policy(config, checkpoint_dir)\n",
    "\n",
    "# Run inference on a dummy example. This example corresponds to observations produced by the DROID runtime.\n",
    "example = droid_policy.make_droid_example()\n",
    "result = policy.infer(example)\n",
    "\n",
    "# Delete the policy to free up memory.\n",
    "del policy\n",
    "\n",
    "print(\"Actions shape:\", result[\"actions\"].shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Working with a live model\n",
    "\n",
    "\n",
    "The following example shows how to create a live model from a checkpoint and compute the training loss. First, we demonstrate this with fake data.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = _config.get_config(\"pi0_aloha_sim\")\n",
    "\n",
    "checkpoint_dir = download.maybe_download(\"s3://openpi-assets/checkpoints/pi0_aloha_sim\")\n",
    "key = jax.random.key(0)\n",
    "\n",
    "# Create a model from the checkpoint.\n",
    "model = config.model.load(_model.restore_params(checkpoint_dir / \"params\"))\n",
    "\n",
    "# We can create fake observations and actions to test the model.\n",
    "obs, act = config.model.fake_obs(), config.model.fake_act()\n",
    "\n",
    "# Compute the training loss for the fake observations and actions.\n",
    "loss = model.compute_loss(key, obs, act)\n",
    "print(\"Loss shape:\", loss.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we are going to create a data loader and use a real batch of training data to compute the loss."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce the batch size to reduce memory usage.\n",
    "config = dataclasses.replace(config, batch_size=2)\n",
    "\n",
    "# Load a single batch of data. This is the same data that will be used during training.\n",
    "# NOTE: In order to make this example self-contained, we are skipping the normalization step\n",
    "# since it requires the normalization statistics to be generated using `compute_norm_stats`.\n",
    "loader = _data_loader.create_data_loader(config, num_batches=1, skip_norm_stats=True)\n",
    "obs, act = next(iter(loader))\n",
    "\n",
    "# Compute the training loss on the batch loaded from the data loader.\n",
    "loss = model.compute_loss(key, obs, act)\n",
    "\n",
    "# Delete the model to free up memory.\n",
    "del model\n",
    "\n",
    "print(\"Loss shape:\", loss.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
